import os, sys

import environments

# import python packages
import time
from copy import deepcopy

# import third-party packages
import numpy as np
import torch
import torch.optim as optim
from multiprocessing import Process, Queue, Event
import pickle
from queue import Queue as Queue_

# +
# import our packages
from scalarization_methods import WeightedSumScalarization
from sample import Sample
from task import Task
from ep import EP
from population_2d import Population as Population2d
from population_3d import Population as Population3d
from population_4d import Population as Population4d
from opt_graph import OptGraph
from utils import generate_weights_batch_dfs, print_info
from warm_up import initialize_warm_up_batch
from p_test import initialize_2D_test_batch, initialize_3D_test_batch, initialize_4D_test_batch
from mopg import MOPG_worker, MOPG_test, MOPG_test_gen_rollout, pgmorl_MOPG_worker
from queue2list import Queue2List
# Evaluation period used by run(): for the 'prediction-guided' and 'pfa'
# selection methods, test rollouts are launched whenever the iteration count
# reached at the end of a generation is a multiple of this value.
test_iter = 20

def _make_scalarization(template, weights):
    """Return a deep copy of ``template`` whose weights were set to ``weights``."""
    scalarization = deepcopy(template)
    scalarization.update_weights(weights)
    return scalarization


def _remember_tasks(task_queue, weight_queue, policies, scalarizations, capacity):
    """Append (policy, scalarization) pairs to the two history queues.

    When ``task_queue`` already holds ``capacity`` entries, one entry is
    removed from the front of both queues before the new pair is added, so
    that the subsequent ``put`` never has to wait on a full queue.
    """
    for idx in range(len(policies)):
        if task_queue.qsize() == capacity:
            task_queue.get()
            weight_queue.get()
        task_queue.put(policies[idx])
        weight_queue.put(scalarizations[idx])


def _collect_worker_results(results_queue, num_workers):
    """Read worker messages until ``num_workers`` of them carried ``done``.

    Each message is a dict with the keys ``'task_id'``, ``'offspring_batch'``
    and ``'done'``.  Returns one list per task id holding copies
    (``Sample.copy_from``) of every offspring received for that task, in
    arrival order.
    """
    collected = [[] for _ in range(num_workers)]
    finished = 0
    while finished < num_workers:
        rl_results = results_queue.get()
        task_id = rl_results['task_id']
        for sample in rl_results['offspring_batch']:
            collected[task_id].append(Sample.copy_from(sample))
        if rl_results['done']:
            finished += 1
    return collected


def _nearest_weight_index(target_weights, candidates):
    """Return the index of the candidate whose ``weights`` are closest to ``target_weights``.

    Distance is ``torch.norm`` of the difference; ties keep the earliest
    candidate.  The search starts from infinity so the true nearest candidate
    is found even when every distance is >= 1 (the previous threshold of 1
    silently fell back to index 0 in that case).  Returns 0 for an empty
    ``candidates`` list.
    """
    best_index, best_distance = 0, float('inf')
    for index, candidate in enumerate(candidates):
        distance = torch.norm(target_weights - candidate.weights).item()
        if distance < best_distance:
            best_index, best_distance = index, distance
    return best_index


def _run_test_rollouts(args, device, iteration, start_time, rollout_task_id,
                       task_queue, weight_queue, elite_batch_test, scalarization_batch_test):
    """Evaluate the remembered policies against the test scalarizations.

    Steps:
      1. Convert both history queues to lists with ``Queue2List`` (which also
         hands back the queue objects to keep using).
      2. Build one rollout per remembered (policy, scalarization) pair with
         ``MOPG_test_gen_rollout``.
      3. Pair every test scalarization with the remembered policy whose
         weights are nearest to it, and start one ``MOPG_test`` process per
         pair, handing it the rollout of the matched policy.
      4. Gather the results and keep the last offspring of every test task.

    Returns ``(last_offspring_batch, task_queue, weight_queue)`` where
    ``last_offspring_batch`` has ``len(elite_batch_test)`` entries.
    """
    last_offspring_batch = [None] * len(elite_batch_test)
    results_queue = Queue()
    done_event = Event()

    task_queue_list, task_queue = Queue2List(task_queue)
    weights_queue_list, weight_queue = Queue2List(weight_queue)

    rollouts = []
    for idx in range(len(task_queue_list)):
        remembered_task = Task(task_queue_list[idx], weights_queue_list[idx])
        rollouts.append(MOPG_test_gen_rollout(args, rollout_task_id, remembered_task, device, iteration, 1, start_time))

    test_tasks = []     # each task is a (policy, weight) pair
    matched_index = []  # matched_index[k]: history index of the policy used by test task k
    for _, scalarization in zip(elite_batch_test, scalarization_batch_test):
        nearest = _nearest_weight_index(scalarization.weights, weights_queue_list)
        test_tasks.append(Task(task_queue_list[nearest], scalarization))
        matched_index.append(nearest)

    processes = []
    for task_id, task in enumerate(test_tasks):
        p = Process(target = MOPG_test, \
            args = (args, task_id, task, device, iteration, 1, start_time, results_queue, done_event, rollouts[matched_index[task_id]]))
        p.start()
        processes.append(p)

    # collect the test results and keep the last policy of every test task
    collected = _collect_worker_results(results_queue, len(processes))
    for task_id, offsprings in enumerate(collected):
        last_offspring_batch[task_id] = offsprings[-1]
    done_event.set()

    return last_offspring_batch, task_queue, weight_queue


def _write_rows(path, rows, obj_num):
    """Write every row of ``rows`` to ``path`` as ``obj_num`` comma-separated ``{:5f}`` values."""
    row_format = '{:5f}' + (obj_num - 1) * ',{:5f}' + '\n'
    with open(path, 'w') as fp:
        for row in rows:
            fp.write(row_format.format(*row))


def _save_policy(sample, directory, index):
    """Save ``sample``'s actor-critic state dict and pickled env_params into ``directory``."""
    torch.save(sample.actor_critic.state_dict(), os.path.join(directory, 'EP_policy_{}.pt'.format(index)))
    with open(os.path.join(directory, 'EP_env_params_{}.pkl'.format(index)), 'wb') as fp:
        pickle.dump(sample.env_params, fp)


def _save_generation(args, iteration, ep, population, opt_graph, elite_batch, scalarization_batch,
                     all_offspring_batch, predicted_offspring_objs,
                     pgmorl_last_offspring_batch, pfa_last_offspring_batch):
    """Write the state reached after one generation below ``<save_dir>/<iteration>/``."""
    obj_num = args.obj_num

    # save ep
    ep_dir = os.path.join(args.save_dir, str(iteration), 'ep')
    os.makedirs(ep_dir, exist_ok = True)
    _write_rows(os.path.join(ep_dir, 'objs.txt'), ep.obj_batch, obj_num)

    # save population
    population_dir = os.path.join(args.save_dir, str(iteration), 'population')
    os.makedirs(population_dir, exist_ok = True)
    _write_rows(os.path.join(population_dir, 'objs.txt'), (sample.objs for sample in population.sample_batch), obj_num)

    # save optgraph (one "weights;objs;prev" line per node) and the node id of
    # each sample in the population
    node_format = '{:5f}' + (obj_num - 1) * ',{:5f}' + ';{:5f}' + (obj_num - 1) * ',{:5f}' + ';{}\n'
    with open(os.path.join(population_dir, 'optgraph.txt'), 'w') as fp:
        fp.write('{}\n'.format(len(opt_graph.objs)))
        for i in range(len(opt_graph.objs)):
            fp.write(node_format.format(*(opt_graph.weights[i]), *(opt_graph.objs[i]), opt_graph.prev[i]))
        fp.write('{}\n'.format(len(population.sample_batch)))
        for sample in population.sample_batch:
            fp.write('{}\n'.format(sample.optgraph_id))

    # save elites
    elite_dir = os.path.join(args.save_dir, str(iteration), 'elites')
    os.makedirs(elite_dir, exist_ok = True)
    _write_rows(os.path.join(elite_dir, 'elites.txt'), (elite.objs for elite in elite_batch), obj_num)
    _write_rows(os.path.join(elite_dir, 'weights.txt'), (scalarization.weights for scalarization in scalarization_batch), obj_num)
    if args.selection_method == 'prediction-guided':
        _write_rows(os.path.join(elite_dir, 'predictions.txt'), predicted_offspring_objs, obj_num)
        _write_rows(os.path.join(elite_dir, 'pgmorl_test.txt'), (elite.objs for elite in pgmorl_last_offspring_batch), obj_num)
    _write_rows(os.path.join(elite_dir, 'offsprings.txt'),
                (sample.objs for offsprings in all_offspring_batch for sample in offsprings), obj_num)
    if args.selection_method == 'pfa':
        _write_rows(os.path.join(elite_dir, 'pfa_test.txt'), (elite.objs for elite in pfa_last_offspring_batch), obj_num)

    # save the policy stored last in the EP
    tmp_i = len(ep.sample_batch) - 1
    model_dir = os.path.join(elite_dir, 'model')
    os.makedirs(model_dir, exist_ok = True)
    _save_policy(ep.sample_batch[tmp_i], model_dir, tmp_i)


def run(args):
    """Run the warm-up stage followed by the evolutionary stage and save all results.

    Every generation (1) optimises each (policy, weight) task in its own
    ``MOPG_worker`` process, (2) updates the ``EP`` archive, the population
    and the optimisation graph with the offspring, (3) selects the tasks of
    the next generation according to ``args.selection_method`` and (4) writes
    the current state below ``args.save_dir/<iteration>/``.  When the
    iteration budget is used up, the final EP is written to
    ``args.save_dir/final/``.

    Attributes read from ``args`` in this function: ``seed``, ``obj_num``,
    ``num_env_steps``, ``num_steps``, ``num_processes``, ``warmup_iter``,
    ``update_iter``, ``selection_method`` ('moead', 'prediction-guided',
    'random', 'ra' or 'pfa'), ``min_weight``, ``max_weight``,
    ``delta_weight``, ``save_dir`` and ``obj_rms``.  ``args`` is also passed
    on unchanged to the project helpers called here.

    Raises:
        NotImplementedError: if ``args.obj_num`` is not 2, 3 or 4, if
            ``args.selection_method`` is unknown, or if 'pfa' is requested
            with more than two objectives.
    """

    # --------------------> Preparation <-------------------- #
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.set_default_dtype(torch.float64)
    torch.set_num_threads(1)
    device = torch.device("cpu")

    # build a scalarization template (uniform weights)
    scalarization_template = WeightedSumScalarization(num_objs = args.obj_num, weights = np.ones(args.obj_num) / args.obj_num)

    total_num_updates = int(args.num_env_steps) // args.num_steps // args.num_processes
    start_time = time.time()

    # initialize ep and population and opt_graph
    ep = EP()
    if args.obj_num == 2:
        population = Population2d(args)
    elif args.obj_num == 3:
        population = Population3d(args)
    elif args.obj_num == 4:
        population = Population4d(args)
    else:
        raise NotImplementedError
    opt_graph = OptGraph()

    # Construct tasks for warm up
    elite_batch, scalarization_batch = initialize_warm_up_batch(args, device)
    Task_num = len(elite_batch)

    # Latest test results.  Initialised unconditionally so that the save step
    # can always read them; previously the 3- and 4-objective branches left
    # pgmorl_last_offspring_batch undefined until the first test round.
    pgmorl_last_offspring_batch = []
    pfa_last_offspring_batch = []
    predicted_offspring_objs = None  # only produced by 'prediction-guided'

    # Construct the fixed test batch
    if args.obj_num == 3:
        elite_batch_test, scalarization_batch_test = initialize_3D_test_batch(args, device)
    elif args.obj_num == 4:
        elite_batch_test, scalarization_batch_test = initialize_4D_test_batch(args, device)
    elif args.selection_method in ('pfa', 'prediction-guided'):
        elite_batch_test, scalarization_batch_test = initialize_2D_test_batch(args, device, 0, 1, 0.01)

    rl_num_updates = args.warmup_iter
    for sample, scalarization in zip(elite_batch, scalarization_batch):
        sample.optgraph_id = opt_graph.insert(deepcopy(scalarization.weights), deepcopy(sample.objs), -1)

    # bounded history of selected (policy, weight) pairs used by the test rollouts
    history_capacity = Task_num * 10
    Task_queue = Queue_(maxsize=history_capacity)
    Weight_queue = Queue_(maxsize=history_capacity)

    episode = 0
    iteration = 0
    while iteration < total_num_updates:
        print(iteration, rl_num_updates)
        if episode == 0:
            print_info('\n------------------------------- Warm-up Stage -------------------------------')
        else:
            print_info('\n-------------------- Evolutionary Stage: Generation {:3} --------------------'.format(episode))

        episode += 1

        # --------------------> RL Optimization <-------------------- #
        # compose task for each elite; each task is a (policy, weight) pair
        task_batch = [Task(elite, scalarization)
                      for elite, scalarization in zip(elite_batch, scalarization_batch)]

        # run MOPG for each task in parallel
        # NOTE(review): the worker processes are never joined here; presumably
        # they exit on their own once done_event is set - confirm in MOPG_worker.
        processes = []
        results_queue = Queue()
        done_event = Event()
        for task_id, task in enumerate(task_batch):
            p = Process(target = MOPG_worker, \
                args = (args, task_id, task, device, iteration, rl_num_updates, start_time, results_queue, done_event))
            p.start()
            processes.append(p)

        # collect MOPG results for offsprings
        all_offspring_batch = _collect_worker_results(results_queue, len(processes))

        # all intermediate policies go into all_sample_batch for the EP update
        all_sample_batch = []
        # the last policy of each optimization weight, used by 'ra' and 'pfa'
        last_offspring_batch = [None] * len(processes)
        # only every update_iter-th policy of a task enters the population and the opt_graph
        offspring_batch = []
        for task_id, offsprings in enumerate(all_offspring_batch):
            prev_node_id = task_batch[task_id].sample.optgraph_id
            opt_weights = deepcopy(task_batch[task_id].scalarization.weights).detach().numpy()
            for i, sample in enumerate(offsprings):
                all_sample_batch.append(sample)
                if (i + 1) % args.update_iter == 0:
                    prev_node_id = opt_graph.insert(opt_weights, deepcopy(sample.objs), prev_node_id)
                    sample.optgraph_id = prev_node_id
                    offspring_batch.append(sample)
            last_offspring_batch[task_id] = offsprings[-1]
        done_event.set()

        # -----------------------> Update EP <----------------------- #
        # update EP and population
        ep.update(all_sample_batch)
        population.update(offspring_batch)

        # iteration count reached once this generation is accounted for
        next_iteration = min(iteration + rl_num_updates, total_num_updates)
        # NOTE(review): the rollout generator used to receive whatever value
        # the loop variable `task_id` was left with, i.e. the id of the last
        # optimisation task.  The value is kept, but passed explicitly.
        rollout_task_id = len(task_batch) - 1

        # ------------------- > Task Selection <--------------------- #
        if args.selection_method == 'moead':
            elite_batch, scalarization_batch = [], []
            weights_batch = []
            generate_weights_batch_dfs(0, args.obj_num, args.min_weight, args.max_weight, args.delta_weight, [], weights_batch)
            for weights in weights_batch:
                scalarization = _make_scalarization(scalarization_template, weights)
                scalarization_batch.append(scalarization)
                # pick the population member with the highest scalarized value
                best_sample, best_value = None, -np.inf
                for sample in population.sample_batch:
                    value = scalarization.evaluate(torch.Tensor(sample.objs))
                    if value > best_value:
                        best_sample, best_value = sample, value
                elite_batch.append(best_sample)
        elif args.selection_method == 'prediction-guided':
            elite_batch, scalarization_batch, predicted_offspring_objs, task_queue_batch, weights_queue_batch = population.prediction_guided_selection(args, iteration, ep, opt_graph, scalarization_template)
            _remember_tasks(Task_queue, Weight_queue, task_queue_batch, weights_queue_batch, history_capacity)
            if next_iteration % test_iter == 0:
                pgmorl_last_offspring_batch, Task_queue, Weight_queue = _run_test_rollouts(
                    args, device, iteration, start_time, rollout_task_id,
                    Task_queue, Weight_queue, elite_batch_test, scalarization_batch_test)
        elif args.selection_method == 'random':
            elite_batch, scalarization_batch = population.random_selection(args, scalarization_template)
        elif args.selection_method == 'ra':
            elite_batch = last_offspring_batch
            weights_batch = []
            generate_weights_batch_dfs(0, args.obj_num, args.min_weight, args.max_weight, args.delta_weight, [], weights_batch)
            scalarization_batch = [_make_scalarization(scalarization_template, weights)
                                   for weights in weights_batch]
        elif args.selection_method == 'pfa':
            if args.obj_num > 2:
                raise NotImplementedError
            elite_batch = last_offspring_batch
            scalarization_batch = []
            # fraction of the post-warm-up budget used so far, clipped to [0, 1];
            # it shifts every grid weight by up to one delta_weight
            delta_ratio = (iteration + rl_num_updates + args.update_iter - args.warmup_iter) / (total_num_updates - args.warmup_iter)
            delta_ratio = np.clip(delta_ratio, 0.0, 1.0)
            for i in np.arange(args.min_weight, args.max_weight + 0.5 * args.delta_weight, args.delta_weight):
                w = np.clip(i + delta_ratio * args.delta_weight, args.min_weight, args.max_weight)
                weights = np.array([abs(w), abs(1.0 - w)])
                scalarization_batch.append(_make_scalarization(scalarization_template, weights))
            _remember_tasks(Task_queue, Weight_queue, elite_batch, scalarization_batch, history_capacity)
            if next_iteration % test_iter == 0:
                pfa_last_offspring_batch, Task_queue, Weight_queue = _run_test_rollouts(
                    args, device, iteration, start_time, rollout_task_id,
                    Task_queue, Weight_queue, elite_batch_test, scalarization_batch_test)
        else:
            raise NotImplementedError

        print_info('Selected Tasks:')
        for i in range(min(len(elite_batch), 5)):
            print_info('objs = {}, weight = {}'.format(elite_batch[i].objs, scalarization_batch[i].weights))

        iteration = next_iteration

        # every generation after the warm-up runs update_iter RL updates
        rl_num_updates = args.update_iter

        # ----------------------> Save Results <---------------------- #
        _save_generation(args, iteration, ep, population, opt_graph, elite_batch, scalarization_batch,
                         all_offspring_batch, predicted_offspring_objs,
                         pgmorl_last_offspring_batch, pfa_last_offspring_batch)

    # ----------------------> Save Final Model <---------------------- #
    final_dir = os.path.join(args.save_dir, 'final')
    os.makedirs(final_dir, exist_ok = True)

    # save ep policies & env_params
    for i, sample in enumerate(ep.sample_batch):
        _save_policy(sample, final_dir, i)

    # save all ep objectives
    _write_rows(os.path.join(final_dir, 'objs.txt'), ep.obj_batch, args.obj_num)

    # save all ep env_params
    if args.obj_rms:
        with open(os.path.join(final_dir, 'env_params.txt'), 'w') as fp:
            for sample in ep.sample_batch:
                fp.write('obj_rms: mean: {} var: {}\n'.format(sample.env_params['obj_rms'].mean, sample.env_params['obj_rms'].var))

